#!/usr/bin/env python3
# -*- coding: utf-8 -*-

import matplotlib.pyplot as plt 
from scipy.stats import poisson
import numpy as np
import time
import random
import pickle
import pandas as pd

def rv_gen(theta):
    """Draw one demand from Poisson(theta) truncated to {0, ..., M_c}.

    Uses rejection sampling: redraw until the sample does not exceed the
    module-level truncation bound M_c.
    """
    xi = poisson.rvs(theta)
    while xi > M_c:
        xi = poisson.rvs(theta)
    return xi

def indicator(x):
    """Return the step function of x with value 1/2 at the origin.

    1 for x > 0, 0.5 for x == 0, and 0 otherwise (including NaN).
    """
    if x > 0:
        return 1
    return 0.5 if x == 0 else 0

def cost_func(s, a, xi, p = 6, h = 4):
    """Stage cost of ordering a units at inventory s when demand is xi.

    Charges h per unit left over (s + a - xi > 0) and p per unit of unmet
    demand (xi - a - s > 0).
    """
    surplus = s + a - xi
    shortage = xi - a - s
    return h * max(0, surplus) + p * max(0, shortage)

def g(s,a,xi):
    """Next inventory level s + a - xi, projected onto the state space S.

    Values already in S are returned unchanged; anything above max(S) maps
    to max(S), and anything else maps to min(S).
    """
    candidate = s + a - xi
    if candidate in S:
        return candidate
    upper = max(S)
    return upper if candidate > upper else min(S)

def product_for_zero_times_inf(x,y):
    """Return x * y, treating a zero factor as absorbing so that 0 * inf is 0 rather than nan."""
    if x == 0 or y == 0:
        return 0
    return x * y

def find_alpha_function(u_vector):
    """Compute the alpha functions of the approximate CVaR BR-MDP for a fixed vector u.

    Backward recursion over stages t = N-1, ..., 0.  For every candidate demand
    rate theta (index_theta), inventory s (index_s) and order quantity a (index_a):

        alpha_function[t][index_theta, index_s, index_a]
            = u_t + max(E_theta[stage cost + continuation] - u_t, 0) / (1 - alpha)

    where the expectation is over xi = 0..M_c with weights
    trun_poisson_dictionary[(theta, xi)].  The last stage has no continuation
    term; earlier stages use alpha_function[t+1] at the projected next state
    g(s, a, xi) and a per-theta greedy next action.

    Reads module-level globals: N, Theta, S, A, alpha, M_c,
    trun_poisson_dictionary, mu_0, index_s_0.

    Args:
        u_vector: indexable of length N; u_vector[t] is the auxiliary threshold
            used at stage t.

    Returns:
        (alpha_function, index_opt_a, V_0):
        - alpha_function: dict t -> array of shape (len(Theta), len(S), len(A)).
        - index_opt_a: dict.  For t >= 1, an int8 array of the same shape whose
          entry [theta, s, a] is the index of the action taken at stage t after
          taking action a in state s at stage t-1.  For t = 0, the index of
          the action minimising Q_0.
        - V_0: min over actions of sum_theta mu_0[theta] * alpha_function[0][theta, s_0, a].
    """
    alpha_function = {}
    index_opt_a = {}
    t = N - 1
    # initialised to inf; every entry is overwritten in the loops below
    alpha_function_t = np.inf * np.ones([len(Theta), len(S), len(A)])
    for index_theta, theta in enumerate(Theta):
        for index_s, s in enumerate(S):
            for index_a, a in enumerate(A):
                # last stage: expected stage cost only, no continuation value
                alpha_function_t[index_theta, index_s, index_a] = 1 / (1 - alpha) * max(sum([cost_func(s, a, xi) * trun_poisson_dictionary[(theta, xi)] for xi in range(M_c + 1)]) - u_vector[t], 0) + u_vector[t]
    alpha_function[t] = alpha_function_t
    
    for t in range(N-2,-1,-1):
        # 100 is a placeholder overwritten for every (theta, s, a) below.
        # NOTE(review): int8 limits action indices to <= 127.
        index_opt_a_next = 100 * np.ones([len(Theta), len(S), len(A)], dtype=np.int8)
        alpha_function_t = np.inf * np.ones([len(Theta), len(S), len(A)])
        
        for index_theta, theta in enumerate(Theta):
            for index_s, s in enumerate(S):
                for index_a, a in enumerate(A):
                    # greedy next action for this theta: argmin over a' of E_xi[alpha_{t+1}(theta, g(s,a,xi), a')]
                    a_next_index = np.argmin(np.sum([alpha_function[t+1][index_theta, S.index(g(s,a,xi)), :] * trun_poisson_dictionary[(theta, xi)] for xi in range(M_c + 1)],0))
                    index_opt_a_next[index_theta, index_s, index_a] = a_next_index
                    # CVaR-style update with the expected stage cost plus the continuation at the greedy next action
                    alpha_function_t[index_theta, index_s, index_a] = 1 / (1 - alpha) * max(sum([(cost_func(s, a, xi) + alpha_function[t+1][index_theta, S.index(g(s,a,xi)), a_next_index]) * trun_poisson_dictionary[(theta, xi)] for xi in range(M_c + 1)]) - u_vector[t], 0) + u_vector[t]
        alpha_function[t] = alpha_function_t
        index_opt_a[t+1] = index_opt_a_next
        
    # belief-weighted Q-values at (s_0, a); product_for_zero_times_inf keeps 0 * inf at 0 instead of nan
    Q_0 = [sum([product_for_zero_times_inf(mu_0[theta], alpha_function[0][index_theta, index_s_0, index_a]) for index_theta, theta in enumerate(Theta)]) for index_a, a in enumerate(A)]
    index_opt_a[0] = np.argmin(Q_0)
    V_0 = min(Q_0)  
    return (alpha_function, index_opt_a, V_0)

def find_u_gradient(u_vector, alpha_function, index_opt_a):
    """Stochastic estimate of dV_0/du_t for t = 0..N-1.

    For each theta, follows one simulated path: it starts at (s_0,
    index_opt_a[0]), draws demand with rv_gen(theta), moves to g(s, a, xi) and
    takes the next action from index_opt_a.  Along the path it accumulates
    chain-rule factors for the update u + max(E - u, 0) / (1 - alpha) used in
    find_alpha_function:
      - d/du of that expression is 1 - indicator(E - u) / (1 - alpha)
        (partial_alpha_t_u_t);
      - its derivative with respect to the continuation term is
        indicator(E - u) / (1 - alpha) = 1 - partial_alpha_t_u_t.

    Consumes random numbers through rv_gen, so repeated calls give different
    estimates.  Reads globals index_s_0, N, S, A, alpha, M_c, Theta, mu_0 and
    trun_poisson_dictionary.

    Args:
        u_vector: length-N auxiliary vector.
        alpha_function, index_opt_a: the first two outputs of find_alpha_function.

    Returns:
        List of length N: entry t is sum_theta mu_0[theta] * d alpha_0 / d u_t
        for that theta's sampled path.
    """
    partial_alpha_0_alpha ={}
    partial_alpha_0_u ={}
    for index_theta, theta in enumerate(Theta):
        index_s = index_s_0
        index_a = index_opt_a[0]
        # d alpha_0 / d alpha_0 = 1
        partial_alpha_0_alpha[(0,index_theta)] = 1
        for t in range(N):
            s = S[index_s]
            a = A[index_a]
            # the sampled demand only selects the next state on the path
            xi = rv_gen(theta)
            index_s_next = S.index(g(s,a,xi))
            if t < N-1:
                index_a_next = index_opt_a[t+1][index_theta, index_s, index_a]
                # partial_alpha_t_u_t = 1 - 1 / (1 - alpha) * indicator(cost_func(s, a, xi) - u_vector[t] + alpha_function[t+1][index_theta, index_s_next, index_a_next])
                # The comprehension's xi is local to it (Python 3 scoping), so this takes the full expectation over demand
                # and does not overwrite the sampled xi above.
                partial_alpha_t_u_t = 1 - 1 / (1 - alpha) * indicator(sum([(cost_func(s, a, xi) + alpha_function[t+1][index_theta, S.index(g(s,a,xi)), index_a_next]) * trun_poisson_dictionary[(theta, xi)] for xi in range(M_c + 1)])  - u_vector[t])
                # chain rule: multiply by d alpha_t / d(continuation) = 1 - partial_alpha_t_u_t
                partial_alpha_0_alpha[(t+1,index_theta)] = partial_alpha_0_alpha[(t,index_theta)] * (1 - partial_alpha_t_u_t)
                # pass the state and action to next stage
                index_s = index_s_next
                index_a = index_a_next
            else:
                # last stage: no continuation term
                partial_alpha_t_u_t = 1 - 1 / (1 - alpha) * indicator(sum([cost_func(s, a, xi)  * trun_poisson_dictionary[(theta, xi)] for xi in range(M_c + 1)])  - u_vector[t])
                # partial_alpha_t_u_t = 1 - 1 / (1 - alpha) * indicator(cost_func(s, a, xi) - u_vector[t])
            partial_alpha_0_u[(t,index_theta)] = partial_alpha_0_alpha[(t,index_theta)] * partial_alpha_t_u_t
            
    # weight each theta's path derivative by the prior mass mu_0[theta]
    u_gradient = [sum([ mu_0[theta] * partial_alpha_0_u[(t,index_theta)] for index_theta, theta in enumerate(Theta)]) for t in range(N)]
    return u_gradient  

def SGD(u_vector, K = 1000, parm1 = 1, parm2 = 1000, SGD_iter = 20):
    """Minimise the approximate BR-MDP value V_0 over the auxiliary vector u.

    Each of the K outer iterations:
      1. recomputes the alpha functions for the current u (find_alpha_function);
      2. records u if its V_0 is the best seen so far;
      3. takes SGD_iter stochastic-gradient steps on u with step size
         eta = parm1 / (parm2 + i), reusing that iteration's alpha functions.

    Args:
        u_vector: initial auxiliary vector (length N).
        K: number of outer iterations; must be >= 1.
        parm1, parm2: step-size schedule constants.
        SGD_iter: gradient steps per outer iteration.

    Returns:
        (u_opt, V_opt, alpha_function_opt) for the iterate with the smallest V_0.

    Raises:
        ValueError: if K < 1 (no iterate would ever be evaluated).
    """
    if K < 1:
        raise ValueError('K must be at least 1, got %r' % (K,))
    V_opt = np.inf
    u_opt = None
    alpha_function_opt = None
    for i in range(K):
        eta = parm1 / (parm2 + i)
        alpha_function, index_opt_a, V_0 = find_alpha_function(u_vector)
        # Always keep the first iterate so the returned values are defined even if V_0 is not finite.
        if u_opt is None or V_0 < V_opt:
            V_opt = V_0
            u_opt = u_vector.copy()
            alpha_function_opt = alpha_function.copy()
        # '_' instead of the former loop name avoids shadowing this function's own name.
        for _ in range(SGD_iter):
            u_gradient = find_u_gradient(u_vector, alpha_function, index_opt_a)
            # update u vector using SGD
            u_vector = u_vector - eta * np.array(u_gradient)
    return (u_opt, V_opt, alpha_function_opt)

def mu_space(n,h,weight,weight_round =4):
    """Enumerate all length-n grids of non-negative multiples of h summing to weight.

    Used to build the discretised belief simplex.  Every coordinate is rounded
    to weight_round decimals, so np.arange drift such as 0.30000000000000004
    does not leak into the grid.

    Args:
        n: number of coordinates (number of candidate parameters).
        h: grid step.
        weight: total mass the coordinates must sum to.
        weight_round: decimals used for rounding, applied at every recursion level.

    Returns:
        List of lists of length n; empty if weight rounds to a negative number.
    """
    remaining = round(weight, weight_round)
    if remaining < 0:
        return []
    if n == 1:
        return [[remaining]]
    points = []
    # arange may overshoot weight by one step due to float drift; such
    # prefixes leave a negative remainder and are pruned by the check above.
    for i in np.arange(0, weight + h, h):
        first = round(float(i), weight_round)
        points += [[first] + rest for rest in mu_space(n - 1, h, weight - first, weight_round)]
    return points
    
def g_2(mu, xi):
    """Bayes-update belief mu after observing demand xi, then snap it to the grid Mu.

    The posterior is proportional to mu[theta] * trun_poisson_dictionary[(theta, xi)].
    It is projected onto the element of Mu with the smallest squared Euclidean
    distance; ties go to the earliest grid point.  The returned object is an
    element of Mu itself.
    """
    unnormalised = {theta: mu[theta] * trun_poisson_dictionary[(theta, xi)] for theta in Theta}
    evidence = sum(unnormalised.values())
    posterior = {theta: w / evidence for theta, w in unnormalised.items()}
    # Nearest grid point by squared L2 distance (a linear scan over Mu).
    return min(Mu, key=lambda m: sum([(posterior[theta] - m[theta])**2 for theta in Theta]))

def alpha_policy_evaluation(alpha_function):
    """Exact expected cost, under the true rate theta_c, of the greedy policy induced by alpha_function.

    At stage t, inventory s and grid belief Mu[index_mu], the action minimises
    sum_theta alpha_function[t][theta, s, a] * mu[theta] (the first minimiser in A
    wins ties).  Its value is then propagated backwards with the truncated
    Poisson(theta_c) demand and the belief transitions in
    posterior_transition_matrix.

    Reads globals N, S, A, Mu, Theta, M_c, theta_c, s_0, mu_0,
    trun_poisson_dictionary and posterior_transition_matrix.

    Returns:
        The value at (s_0, Mu.index(mu_0)) at stage 0; mu_0 must be an element of Mu.
    """
    V_DP = {}
    pi = {}  # greedy actions; filled in but not returned
    V_DP[N] = {}
    for s in S:
        for index_mu, mu in enumerate(Mu):
            V_DP[N][(s,index_mu)] = 0
    for t in range(N-1,-1,-1):
        V_DP[t] = {}
        pi[t] = {}
        for s in S:
            for index_mu, mu in enumerate(Mu):
                V_opt = float('inf')
                index_s = S.index(s)
                for index_a, a in enumerate(A):
                    # NOTE(review): plain product here (unlike Q_0 in find_alpha_function); an inf alpha entry with zero belief mass would give nan
                    V = sum([alpha_function[t][index_theta, index_s, index_a] * mu[theta] for index_theta, theta in enumerate(Theta)])
                    if V < V_opt:
                        V_opt = V
                        a_opt = a
                pi[t][(s,index_mu)] = a_opt
                # evaluate the chosen action under the true demand rate theta_c
                value = sum([(cost_func(s, a_opt, xi) + V_DP[t+1][(g(s, a_opt, xi), posterior_transition_matrix[(index_mu, xi)])] ) * trun_poisson_dictionary[(theta_c, xi)] for xi in range(M_c + 1)])
                V_DP[t][(s,index_mu)] = value
            
            
    # print(V_DP[0][(s_0,mu_list.index(mu_0))])
    return V_DP[0][(s_0,Mu.index(mu_0))]



# Nominal
def DP_mle(theta_mle):
    """Backward induction for the nominal MDP with the demand rate fixed at theta_mle.

    Returns:
        (pi, V_0): pi[t][s] is an order quantity minimising expected stage
        cost plus cost-to-go (the first minimiser in A wins ties), and V_0 is
        the optimal expected N-stage cost starting from s_0.
    """
    demand_pmf = [trun_poisson_dictionary[(theta_mle, xi)] for xi in range(M_c + 1)]
    cost_to_go = {s: 0 for s in S}  # terminal cost is zero
    pi = {}
    for t in reversed(range(N)):
        stage_values = {}
        stage_policy = {}
        for s in S:
            best_value = float('inf')
            for a in A:
                q_value = sum((cost_func(s, a, xi) + cost_to_go[g(s, a, xi)]) * prob
                              for xi, prob in enumerate(demand_pmf))
                if q_value < best_value:
                    best_value, best_action = q_value, a
            stage_values[s] = best_value
            stage_policy[s] = best_action
        pi[t] = stage_policy
        cost_to_go = stage_values
    return (pi, cost_to_go[s_0])

def DP_mle_policy_evaluation(pi_mle, theta_c):
    """Exact expected N-stage cost from s_0 of the Markov policy pi_mle[t][s].

    Demand follows the truncated Poisson(theta_c) in trun_poisson_dictionary.
    """
    future = {s: 0 for s in S}
    for t in reversed(range(N)):
        policy_t = pi_mle[t]
        current = {}
        for s in S:
            a = policy_t[s]
            current[s] = sum([(cost_func(s, a, xi) + future[g(s, a, xi)]) * trun_poisson_dictionary[(theta_c, xi)] for xi in range(M_c + 1)])
        future = current
    return future[s_0]


def rho_metric(V_distribution, metric, q):
    """Evaluate a risk functional of a finite discrete distribution.

    Args:
        V_distribution: dict mapping each outcome value V to its probability mass.
        metric: 'VaR', 'CVaR' or 'mean'.
        q: confidence level in [0, 1] for 'VaR' / 'CVaR' (ignored for 'mean').

    Returns:
        - 'VaR': the smallest V whose cumulative mass reaches q, or the largest
          V if rounding leaves the total mass below q.
        - 'CVaR': upper-tail CVaR in the Rockafellar-Uryasev sense,
          VaR_q + E[(V - VaR_q)^+] / (1 - q).  This is the mean of the worst
          (1 - q) of the mass, and only the part of the atom above level q
          counts.  It matches the u + max(E - u, 0) / (1 - alpha) form used
          by find_alpha_function.
        - 'mean': the expectation of V.

    Raises:
        ValueError: for an unknown metric, or an empty distribution with 'VaR'/'CVaR'.
    """
    if metric == 'mean':
        return sum([V * V_distribution[V] for V in V_distribution])
    if metric not in ('VaR', 'CVaR'):
        raise ValueError("metric must be 'VaR', 'CVaR' or 'mean', got %r" % (metric,))
    if not V_distribution:
        raise ValueError('V_distribution must contain at least one outcome')
    values = sorted(V_distribution)
    if metric == 'VaR':
        cumulative = 0
        for V in values:
            cumulative += V_distribution[V]
            if cumulative >= q:
                return V
        # total mass fell short of q through rounding: the quantile is the largest outcome
        return values[-1]
    # CVaR
    q = max(q, 0)
    tail_mass = sum(V_distribution[V] for V in values) - q
    if tail_mass <= 0:
        # level at (or numerically beyond) the total mass: CVaR degenerates to the maximum
        return values[-1]
    cumulative = 0
    numerator = 0
    for V in values:
        mass = V_distribution[V]
        previous = cumulative
        cumulative += mass
        if cumulative < q:
            continue
        # the atom that crosses q contributes only its mass above q
        weight = mass if previous >= q else cumulative - q
        numerator += V * weight
    return numerator / tail_mass


def DP_BRMDP(q, metric = 'CVaR'):
    """Solve the exact BR-MDP on the augmented state (inventory, grid belief index).

    For each action, every candidate theta produces an expected value
    E_theta[cost + V_{t+1}(next state, next belief)].  These values, weighted
    by the belief mu[theta], form a discrete distribution, and the action
    minimising rho_metric(distribution, metric, q) is chosen (first minimiser
    in A wins ties).

    Reads globals N, S, A, Mu, Theta, M_c, s_0, mu_0, trun_poisson_dictionary
    and posterior_transition_matrix.

    Returns:
        (pi, V_0): pi[t][(s, index_mu)] is the chosen action, and V_0 is the
        value at (s_0, Mu.index(mu_0)); mu_0 must be an element of Mu.
    """
    V_DP = {}
    pi = {}
    V_DP[N] = {}
    for s in S:
        for index_mu, mu in enumerate(Mu):
            V_DP[N][(s,index_mu)] = 0
    for t in range(N-1,-1,-1):
        V_DP[t] = {}
        pi[t] = {}
        V_next = V_DP[t+1]
        for s in S:
            for index_mu, mu in enumerate(Mu):
                V_opt = float('inf')
                for index_a, a in enumerate(A):
                    # Stage cost plus continuation does not depend on theta, so compute it once per xi.
                    stage_values = []
                    for xi in range(M_c + 1):
                        index_mu_next = posterior_transition_matrix[(index_mu, xi)]
                        s_next = g(s,a,xi)
                        stage_values.append(cost_func(s, a, xi) + V_next[(s_next, index_mu_next)])
                    # distribution over theta of the theta-conditional expected value
                    V_distribution = {}
                    for theta in Theta:
                        value = 0
                        for xi, stage_value in enumerate(stage_values):
                            value += stage_value * trun_poisson_dictionary[(theta, xi)]
                        # equal values from different thetas share one atom
                        if value in V_distribution:
                            V_distribution[value] += mu[theta]
                        else:
                            V_distribution[value] = mu[theta]
                    V = rho_metric(V_distribution, metric, q)
                    if V < V_opt:
                        V_opt = V
                        a_opt = a
                V_DP[t][(s,index_mu)] = V_opt
                pi[t][(s,index_mu)] = a_opt
    return (pi, V_DP[0][(s_0,Mu.index(mu_0))])

def DP_BRMDP_evaluation(pi, theta_c):
    """Exact expected cost of the BR-MDP policy pi under the true demand rate theta_c.

    pi is indexed as pi[t][(s, index_mu)].  Beliefs evolve through
    posterior_transition_matrix.  Returns the value at (s_0, Mu.index(mu_0)).
    """
    future = {(s, index_mu): 0 for s in S for index_mu in range(len(Mu))}
    for t in reversed(range(N)):
        current = {}
        for s in S:
            for index_mu in range(len(Mu)):
                a = pi[t][(s, index_mu)]
                total = 0
                for xi in range(M_c + 1):
                    successor = (g(s, a, xi), posterior_transition_matrix[(index_mu, xi)])
                    total += (cost_func(s, a, xi) + future[successor]) * trun_poisson_dictionary[(theta_c, xi)]
                current[(s, index_mu)] = total
        future = current
    return future[(s_0, Mu.index(mu_0))]


def prior_update(mu_t, data):
    """Bayes-update belief mu_t with all observations in data, then snap to the grid Mu.

    Each theta's weight is multiplied by the truncated Poisson likelihood of
    every observation, and the result is normalised.  The projection returns
    the element of Mu with the smallest squared Euclidean distance; ties go
    to the earliest grid point.  mu_t itself is not modified.
    """
    posterior = mu_t.copy()
    for theta in Theta:
        running = posterior[theta]
        for xi in data:
            running = running * trun_poisson_dictionary[(theta, xi)]
        posterior[theta] = running
    total = sum(posterior.values())
    for theta in Theta:
        posterior[theta] = posterior[theta] / total
    # nearest grid point by squared L2 distance
    return min(Mu, key=lambda m: sum([(posterior[theta] - m[theta])**2 for theta in Theta]))


def DP_DRMDP():
    """Pick the nominal policy of the theta in Theta with the largest optimal cost.

    Solves DP_mle for every candidate theta and returns (a shallow copy of)
    the policy of the first theta attaining the maximal optimal V_0, together
    with that V_0.
    """
    solutions = [DP_mle(theta) for theta in Theta]
    worst_pi, worst_value = None, -np.inf
    for pi_theta, V_0_theta in solutions:
        if V_0_theta > worst_value:
            worst_pi, worst_value = pi_theta, V_0_theta
    return (worst_pi.copy(), worst_value)

def alpha_policy_evaluation_trajectory(alpha_function):
    """Simulate one N-stage trajectory of the alpha-function greedy policy and return its total cost.

    At each stage the action minimises
    sum_theta alpha_function[t][theta, s, a] * mu[theta] over the full action
    set A (first minimiser wins ties).  This is the same rule
    alpha_policy_evaluation uses, so the returned cost is a Monte Carlo
    sample of the quantity that function computes exactly.  Demand is drawn
    with rv_gen(theta_c); the belief follows posterior_transition_matrix.
    """
    V_trajectory = 0
    s = s_0
    mu = mu_0.copy()
    index_mu = Mu.index(mu)
    for t in range(N):
        V_opt = float('inf')
        index_s = S.index(s)
        for index_a, a in enumerate(A):
            V = sum([alpha_function[t][index_theta, index_s, index_a] * mu[theta] for index_theta, theta in enumerate(Theta)])
            if V < V_opt:
                V_opt = V
                a_opt = a
        xi = rv_gen(theta_c)
        V_trajectory += cost_func(s,a_opt, xi)
        s = g(s,a_opt, xi)
        index_mu = posterior_transition_matrix[(index_mu, xi)]
        mu = Mu[index_mu].copy()
    return V_trajectory

def DP_mle_policy_evaluation_trajectory(pi, theta_c):
    """Simulate one N-stage trajectory of the Markov policy pi[t][s] under rate theta_c.

    Returns the realised total cost; demands are drawn with rv_gen.
    """
    total_cost = 0
    state = s_0
    for t in range(N):
        order = pi[t][state]
        demand = rv_gen(theta_c)
        total_cost += cost_func(state, order, demand)
        state = g(state, order, demand)
    return total_cost

def DP_BRMDP_evaluation_trajectory(pi, theta_c):
    """Simulate one N-stage trajectory of the BR-MDP policy pi under rate theta_c.

    pi is indexed as pi[t][(s, index_mu)]; the belief index follows
    posterior_transition_matrix after each observed demand.  Returns the
    realised total cost.
    """
    belief_index = Mu.index(mu_0)
    state = s_0
    total_cost = 0
    for t in range(N):
        order = pi[t][(state, belief_index)]
        demand = rv_gen(theta_c)
        total_cost += cost_func(state, order, demand)
        state = g(state, order, demand)
        belief_index = posterior_transition_matrix[(belief_index, demand)]
    return total_cost

print('Start the program.')
N = 6 # time horizon
s_0 = 5  # initial inventory level

M = 15  # inventory levels and order quantities range over 0..M-1
M_c = 20  # demand is truncated to 0..M_c
S = list(range(0,M))
A = list(range(0,M))
index_s_0 = S.index(s_0)
Theta = [4,6,8,10,12,14,16]  # candidate Poisson demand rates
theta_c = 12  # rate used to generate data and to evaluate policies

# Truncated Poisson pmf, renormalised over 0..M_c.
trun_poisson_dictionary = {}
for theta in Theta:
    weight = poisson.cdf(M_c,theta)
    for xi in range(M_c + 1):
        trun_poisson_dictionary[(theta, xi)] = poisson.pmf(xi,theta)/weight
# Uniform prior over Theta (each experiment loop below re-creates and updates it).
mu_0 = {}
for theta in Theta:
    mu_0[theta] = 1/len(Theta)
# Belief grid: all probability vectors over Theta with step 0.1.
Mu_space = mu_space(len(Theta),0.1,1)
Mu = []
for x in Mu_space:
    mu = {}
    for i in range(len(Theta)):
        theta = Theta[i]
        mu[theta] = x[i]
    Mu += [mu]

# posterior_transition_matrix[(index_mu, xi)] is the index in Mu of the projected
# posterior after observing demand xi from belief Mu[index_mu].
# g_2 returns an element of Mu itself, so an identity -> index map gives the same
# index as Mu.index(...) without a linear scan per entry.
index_of_belief = {id(belief): k for k, belief in enumerate(Mu)}
posterior_transition_matrix = {}
for index_mu, mu in enumerate(Mu):
    for  xi in range(M_c + 1):
        mu_next = g_2(mu, xi)
        index_next = index_of_belief.get(id(mu_next))
        if index_next is None:
            # not one of the Mu objects: fall back to an equality search
            index_next = Mu.index(mu_next)
        posterior_transition_matrix[(index_mu, xi)] = index_next
        
print('Preparation done.')

# comparison:
# Experiment 1: 100 repetitions at CVaR level alpha.  Each repetition draws a
# dataset, forms the projected posterior mu_0, and evaluates under theta_c:
# approximate BR-MDP (SGD over u), nominal MDP at the grid MLE, exact BR-MDP
# and the DR-MDP policy.
value_approx_BRMDP_stats = {}
value_mle_stats = {}
value_BRMDP_stats = {}
times_approx_BRMDP = []
times_BRMDP = []
times_Nominal = []
times_BRMDP_1 = []
times_DRMDP = []
# DP_DRMDP does not read the data or mu_0, so it is solved once outside the loop.
pi_DRMDP, V_0_DRMDP_star = DP_DRMDP()
value_DRMDP_stats = {}

# Read as a global by find_alpha_function / find_u_gradient; must be set before SGD runs.
alpha = 0.4

value_approx_BRMDP = []
value_mle = []
value_BRMDP = []
value_DRMDP = []

# NOTE(review): rv_gen draws via scipy's poisson.rvs without a random_state;
# this presumably relies on scipy using NumPy's global RNG — confirm.
np.random.seed(1)
for i in range(100):
    start_time = time.time()
    print('iteration: ',i)
    # The comprehension's i is local to it (Python 3), so the loop counter is unaffected.
    data = [rv_gen(theta_c) for i in range(10)]
    mu_0 = {}
    for theta in Theta:
        mu_0[theta] = 1/len(Theta)
    # projected posterior; an element of Mu, as required by the Mu.index(mu_0) lookups
    mu_0 = prior_update(mu_0, data)
    u_vector = 10 * np.ones(N)
    t_before = time.time()
    # SGD also draws demands (find_u_gradient calls rv_gen), advancing the same random stream.
    u_vector, V_0_approx_BRMDP_star, alpha_function = SGD(u_vector, K = 200, parm1 = 10, parm2 = 1)
    t_after = time.time()
    times_approx_BRMDP += [t_after - t_before]
    V_0_approx_BRMDP = alpha_policy_evaluation(alpha_function)
    value_approx_BRMDP += [V_0_approx_BRMDP]
    # nominal rate: grid point in Theta closest to the sample mean (first wins ties)
    theta_mle_og = np.mean(data)
    diff = np.inf
    for theta in Theta:
        if diff > abs(theta - theta_mle_og):
            diff = abs(theta - theta_mle_og)
            theta_mle = theta
    t_before = time.time()
    pi_mle, V_0_mle_star = DP_mle(theta_mle)
    t_after = time.time()
    times_Nominal += [t_after - t_before]
    V_0_mle = DP_mle_policy_evaluation(pi_mle, theta_c)
    value_mle += [V_0_mle]
    t_before = time.time()
    pi_BRMDP, V_0_BRMDP_star = DP_BRMDP(q = alpha)
    t_after = time.time()
    times_BRMDP += [t_after - t_before]
    V_0_BRMDP = DP_BRMDP_evaluation(pi_BRMDP, theta_c)
    value_BRMDP += [V_0_BRMDP]
    V_0_DRMDP = DP_mle_policy_evaluation(pi_DRMDP, theta_c)
    value_DRMDP += [V_0_DRMDP]
    print('approx DP V_0: ', V_0_approx_BRMDP, 'Nominal V_0: ', V_0_mle, 'Exact BRMDP V_0: ', V_0_BRMDP, 'theta_Nominal: ', theta_mle)
    print('loop time:', time.time() - start_time)
            
# results keyed by (true rate, risk level)
value_approx_BRMDP_stats[(theta_c, alpha)] = value_approx_BRMDP.copy()
value_mle_stats[(theta_c, alpha)] = value_mle.copy()
value_BRMDP_stats[(theta_c, alpha)] = value_BRMDP.copy()
value_DRMDP_stats[(theta_c, alpha)] = value_DRMDP.copy()

# plt.hist([value_approx_BRMDP, value_BRMDP, value_mle], density=False, bins = 20, label = [r'CVaR BR-MDP (approx, $\alpha$ = %0.2f)' %(alpha),r'CVaR BR-MDP (exact, $\alpha$ = %0.2f)' %(alpha),'Nominal'])
# plt.legend()
# plt.title(r'Histogram: $\theta_c$ = %0.2f' % (theta_c))
# plt.ylabel('frequency')
# plt.xlabel('true performance')
# plt.savefig('Histogram: theta = %f, alpha = %f.png' % (theta_c, alpha), dpi=300)
# plt.show()
# Histogram of true (theta_c) performance of DR-MDP, exact BR-MDP and nominal policies.
plt.hist([value_DRMDP, value_BRMDP, value_mle], density=False, bins = 20, label = ['DR-MDP',r'CVaR BR-MDP (exact, $\alpha$ = %0.2f)' %(alpha),'Nominal'])
plt.legend()
plt.title(r'Histogram: $\theta_c$ = %0.2f' % (theta_c))
plt.ylabel('frequency')
plt.xlabel('true performance')
# Plain filename: no backslash escapes (a non-raw '\a' is a BEL character) and
# no ':' or '$', which are invalid or awkward on some filesystems.
plt.savefig('Histogram_no_approx_BRMDP_theta_%f_alpha_%f.png' % (theta_c, alpha), dpi=300)
plt.show()


# Experiment 2: exact BR-MDP at q = 0.9999 (recorded under key 1, i.e. near
# worst case over the posterior) versus DR-MDP, each timed per repetition.
times_BRMDP_1 = []
times_DRMDP = []
value_BRMDP = []
value_DRMDP = []

# NOTE(review): although the seed matches experiment 1, the datasets differ.
# Experiment 1 also drew random numbers inside SGD between datasets; this loop does not.
np.random.seed(1)
for i in range(100):
    print('second iteration: ', i)
    data = [rv_gen(theta_c) for i in range(10)]
    mu_0 = {}
    for theta in Theta:
        mu_0[theta] = 1/len(Theta)
    mu_0 = prior_update(mu_0, data)
    t_before = time.time()
    pi_BRMDP, V_0_BRMDP_star = DP_BRMDP(q = 0.9999)
    t_after = time.time()
    times_BRMDP_1 += [t_after - t_before]
    V_0_BRMDP = DP_BRMDP_evaluation(pi_BRMDP, theta_c)
    value_BRMDP += [V_0_BRMDP]
    # DP_DRMDP does not depend on the data; it is re-solved here only to record its run time.
    t_before = time.time()
    pi_DRMDP, V_0_DRMDP_star = DP_DRMDP()
    t_after = time.time()
    times_DRMDP += [t_after - t_before]
    V_0_DRMDP = DP_mle_policy_evaluation(pi_DRMDP, theta_c)
    value_DRMDP += [V_0_DRMDP]
    
value_BRMDP_stats[(theta_c, 1)] = value_BRMDP.copy()
value_DRMDP_stats[(theta_c, 1)] = value_DRMDP.copy()
# per-repetition solve times (seconds) for every method
time_stats = {}
time_stats['DRMDP'] = times_DRMDP
time_stats['Exact BRMDP alpha = 1'] = times_BRMDP_1
time_stats['Exact BRMDP'] = times_BRMDP
time_stats['Approximate BRMDP'] = times_approx_BRMDP
time_stats['Nominal'] = times_Nominal

